{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use %pip (not !pip) so the package is installed into this kernel's environment\n",
    "%pip install peft"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7fcbbb7cd3f0>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from peft import LoraConfig, PeftModel\n",
    "\n",
    "\n",
    "torch.manual_seed(12046)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Lora(nn.Module):\n",
    "    \n",
    "    def __init__(self, model, r=4, lora_alpha=16):\n",
    "        '''\n",
    "        Add LoRA adapter to linear model\n",
    "        \n",
    "        Args:\n",
    "        ----\n",
    "        model: Linear model\n",
    "        r: int, The rank of LoRA\n",
    "        lora_alpha: int, Alpha in LoRA algo\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.model = model\n",
    "        # Freeze model\n",
    "        self._freezing_model()\n",
    "        self.lora_A = nn.Linear(model.in_features, r, bias=False)\n",
    "        self.lora_B = nn.Linear(r, model.out_features, bias=False)\n",
    "        # Define the scaling factor of LoRA\n",
    "        self.scaling = lora_alpha / r\n",
    "        \n",
    "    def _freezing_model(self):\n",
    "        for p in self.model.parameters():\n",
    "            p.requires_grad = False\n",
    "        \n",
    "    def forward(self, x):\n",
    "        origin = self.model(x)\n",
    "        delta = self.lora_B(self.lora_A(x)) * self.scaling\n",
    "        return origin + delta"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(True)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def _test_lora(model, r=4, lora_alpha=16):\n",
    "    '''\n",
    "    Test LoRA\n",
    "    \n",
    "    Check our Lora wrapper against the reference peft implementation:\n",
    "    copy peft's randomly initialised adapter weights into our wrapper\n",
    "    and verify the two produce the same output.\n",
    "    \n",
    "    Args:\n",
    "    ----\n",
    "    model: Linear model to wrap\n",
    "    r: int, The rank of LoRA\n",
    "    lora_alpha: int, Alpha in LoRA algo\n",
    "    '''\n",
    "    lora_model = Lora(model, r, lora_alpha)\n",
    "    # Define the reference model\n",
    "    # Wrap in a ModuleDict so the layer has a name ('lin') that\n",
    "    # LoraConfig's target_modules can match\n",
    "    _model = nn.ModuleDict({'lin': model})\n",
    "    config = LoraConfig(\n",
    "        r=r, lora_alpha=lora_alpha,\n",
    "        target_modules=['lin'],\n",
    "        # For test purpose, we randomly initialise LoRA\n",
    "        # In general, we do NOT change the following parameter\n",
    "        init_lora_weights=False)\n",
    "    peft_model = PeftModel(_model, config)\n",
    "    lin = peft_model.base_model.model.lin\n",
    "    # Copy the parameter of LoRA ('default' is peft's adapter name)\n",
    "    lora_model.lora_A.weight.data = lin.lora_A.default.weight.clone()\n",
    "    lora_model.lora_B.weight.data = lin.lora_B.default.weight.clone()\n",
    "    x = torch.randn(10, model.in_features)\n",
    "    # The two implementations should agree up to floating-point tolerance\n",
    "    return torch.all(torch.abs(lora_model(x) - lin(x)) < 1e-3)\n",
    "\n",
    "\n",
    "linear_model = nn.Linear(10, 20)\n",
    "_test_lora(linear_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def print_trainable_parameters(model):\n",
    "    \"\"\"\n",
    "    Print the number of trainable parameters\n",
    "    \"\"\"\n",
    "    trainable_params = 0\n",
    "    all_param = 0\n",
    "    for _, param in model.named_parameters():\n",
    "        all_param += param.numel()\n",
    "        if param.requires_grad:\n",
    "            trainable_params += param.numel()\n",
    "    trainable = f'trainable params: {trainable_params:,}'\n",
    "    params = f'all params: {all_param:,}'\n",
    "    percent = f'trainable%: {100 * trainable_params / all_param:.3f}'\n",
    "    print(f'{trainable} || {params} || {percent}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trainable params: 66 || all params: 66 || trainable%: 100.000\n",
      "tensor([[-1.3469e-03, -3.9655e-01,  5.7396e-01,  3.1267e-01, -1.6206e+00,\n",
      "         -1.1444e-01],\n",
      "        [ 1.4498e-01,  2.1979e-02,  7.9094e-01,  5.0265e-01, -1.9905e-01,\n",
      "         -2.0630e-01],\n",
      "        [ 6.7352e-01, -5.4856e-01, -1.0576e-01, -1.1910e+00,  6.0106e-01,\n",
      "         -3.6762e-01]], grad_fn=<AddmmBackward0>)\n",
      "trainable params: 64 || all params: 130 || trainable%: 49.231\n",
      "tensor([[-1.2706, -0.6598, -1.7895, -2.0696, -0.8138,  0.5505],\n",
      "        [-0.1612,  0.5151,  1.3028, -0.0054, -0.1200, -1.4266],\n",
      "        [ 0.9780, -1.2645, -0.0722,  3.4606,  0.6124, -0.9581]],\n",
      "       grad_fn=<AddBackward0>)\n"
     ]
    }
   ],
   "source": [
    "# Illustrate the effect of LoRA\n",
    "linear_model = nn.Linear(10, 6)\n",
    "x = torch.randn(3, 10)\n",
    "# Linear model: all 66 parameters (60 weights + 6 biases) are trainable\n",
    "print_trainable_parameters(linear_model)\n",
    "print(linear_model(x))\n",
    "# Add LoRA adapter to linear model\n",
    "# The 66 original parameters are frozen; only the 64 adapter\n",
    "# parameters (A: 4x10, B: 6x4) remain trainable\n",
    "lora_model = Lora(linear_model)\n",
    "print_trainable_parameters(lora_model)\n",
    "print(lora_model(x))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using MLP to show the usage of LoRA\n",
    "class MLP(nn.Module):\n",
    "    \n",
    "    def __init__(self, bias=False):\n",
    "        '''\n",
    "        Multi-layer perceptron\n",
    "        \n",
    "        Args:\n",
    "        ----\n",
    "        bias: bool, Whether the linear layers use bias terms\n",
    "        '''\n",
    "        super().__init__()\n",
    "        self.lin0 = nn.Linear(2, 4, bias=bias)\n",
    "        self.lin1 = nn.Linear(4, 2, bias=bias)\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Two linear layers with a ReLU in between\n",
    "        x = F.relu(self.lin0(x))\n",
    "        x = self.lin1(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "MLP(\n",
       "  (lin0): Linear(in_features=2, out_features=4, bias=False)\n",
       "  (lin1): Linear(in_features=4, out_features=2, bias=False)\n",
       ")"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the base MLP and a sample input\n",
    "model = MLP()\n",
    "x = torch.randn(2)\n",
    "model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.2006,  0.0176], grad_fn=<SqueezeBackward4>)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# The output of normal MLP\n",
    "# Keep it as a reference to compare against after adding LoRA\n",
    "origin_re = model(x)\n",
    "origin_re"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModel(\n",
       "  (base_model): LoraModel(\n",
       "    (model): MLP(\n",
       "      (lin0): Linear(\n",
       "        in_features=2, out_features=4, bias=False\n",
       "        (lora_dropout): ModuleDict(\n",
       "          (lora1): Identity()\n",
       "        )\n",
       "        (lora_A): ModuleDict(\n",
       "          (lora1): Linear(in_features=2, out_features=2, bias=False)\n",
       "        )\n",
       "        (lora_B): ModuleDict(\n",
       "          (lora1): Linear(in_features=2, out_features=4, bias=False)\n",
       "        )\n",
       "        (lora_embedding_A): ParameterDict()\n",
       "        (lora_embedding_B): ParameterDict()\n",
       "      )\n",
       "      (lin1): Linear(in_features=4, out_features=2, bias=False)\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "config = LoraConfig(\n",
    "    r=2,\n",
    "    lora_alpha=16,\n",
    "    target_modules=['lin0'],\n",
    "    # For test purpose, we randomly initialise LoRA\n",
    "    # In general, we do NOT change the following parameter\n",
    "    init_lora_weights=False)\n",
    "\n",
    "# Add LoRA adapter to model, under the adapter name 'lora1'\n",
    "peft_model = PeftModel(model, config, adapter_name='lora1')\n",
    "peft_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.2006,  0.0176], grad_fn=<SqueezeBackward4>),\n",
       " tensor([-0.3068,  0.1873], grad_fn=<SqueezeBackward4>),\n",
       " tensor([-0.3068,  0.1873], grad_fn=<SqueezeBackward4>))"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# After adding LoRA adapter, the original model has changed too:\n",
    "# peft wraps the target module of `model` in place, so calling\n",
    "# `model(x)` now also applies the adapter\n",
    "origin_re, peft_model(x), model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([-0.2006,  0.0176])\n",
      "tensor([-0.2006,  0.0176], grad_fn=<SqueezeBackward4>)\n"
     ]
    }
   ],
   "source": [
    "# Disable LoRA to get original model\n",
    "# Inside the context, peft_model's output matches the reference output\n",
    "with peft_model.disable_adapter():\n",
    "    print(peft_model(x))\n",
    "print(origin_re)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([-0.2006,  0.0176], grad_fn=<SqueezeBackward4>),\n",
       " tensor([-0.2006,  0.0176]),\n",
       " tensor([-0.2006,  0.0176]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Unload LoRA, model returns to original state\n",
    "peft_model.unload()\n",
    "origin_re, peft_model(x), model(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PeftModel(\n",
       "  (base_model): LoraModel(\n",
       "    (model): MLP(\n",
       "      (lin0): Linear(\n",
       "        in_features=2, out_features=4, bias=False\n",
       "        (lora_dropout): ModuleDict(\n",
       "          (lora1): Identity()\n",
       "          (lora2): Identity()\n",
       "        )\n",
       "        (lora_A): ModuleDict(\n",
       "          (lora1): Linear(in_features=2, out_features=3, bias=False)\n",
       "          (lora2): Linear(in_features=2, out_features=5, bias=False)\n",
       "        )\n",
       "        (lora_B): ModuleDict(\n",
       "          (lora1): Linear(in_features=3, out_features=4, bias=False)\n",
       "          (lora2): Linear(in_features=5, out_features=4, bias=False)\n",
       "        )\n",
       "        (lora_embedding_A): ParameterDict()\n",
       "        (lora_embedding_B): ParameterDict()\n",
       "      )\n",
       "      (lin1): Linear(\n",
       "        in_features=4, out_features=2, bias=False\n",
       "        (lora_dropout): ModuleDict(\n",
       "          (lora2): Identity()\n",
       "        )\n",
       "        (lora_A): ModuleDict(\n",
       "          (lora2): Linear(in_features=4, out_features=5, bias=False)\n",
       "        )\n",
       "        (lora_B): ModuleDict(\n",
       "          (lora2): Linear(in_features=5, out_features=2, bias=False)\n",
       "        )\n",
       "        (lora_embedding_A): ParameterDict()\n",
       "        (lora_embedding_B): ParameterDict()\n",
       "      )\n",
       "    )\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add multiple LoRA adapters to model\n",
    "# lora1 targets only lin0; lora2 targets both lin0 and lin1\n",
    "config1 = LoraConfig(r=3, lora_alpha=16, target_modules=['lin0'])\n",
    "config2 = LoraConfig(r=5, lora_alpha=16, target_modules=['lin0', 'lin1'])\n",
    "\n",
    "model = MLP()\n",
    "peft_model = PeftModel(model, config1, adapter_name='lora1')\n",
    "peft_model.add_adapter(peft_config=config2, adapter_name='lora2')\n",
    "peft_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "False\n",
      "True\n",
      "True\n"
     ]
    }
   ],
   "source": [
    "# The parameters of original model can NOT be trained\n",
    "print(peft_model.base_model.model.lin1.weight.requires_grad)\n",
    "# The parameters of two LoRA adapters can be trained\n",
    "print(peft_model.base_model.model.lin0.lora_B.lora2.weight.requires_grad)\n",
    "print(peft_model.base_model.model.lin0.lora_B.lora1.weight.requires_grad)\n",
    "optimizer = optim.SGD(peft_model.parameters(), lr=0.1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "active adapter: lora1\n",
      "before bp, lora1: None\n",
      "before bp, lora2: None\n",
      "after bp, lora1: tensor([[ 0.0000, -0.0000,  0.0000],\n",
      "        [ 0.0995, -0.6875,  0.6329],\n",
      "        [-1.0511,  7.2626, -6.6856],\n",
      "        [ 0.0000, -0.0000,  0.0000]])\n",
      "after bp, lora2: None\n"
     ]
    }
   ],
   "source": [
    "# Use one adapter\n",
    "optimizer.zero_grad()\n",
    "print(f'active adapter: {peft_model.active_adapter}')\n",
    "print(f'before bp, lora1: {peft_model.base_model.model.lin0.lora_B.lora1.weight.grad}')\n",
    "print(f'before bp, lora2: {peft_model.base_model.model.lin0.lora_B.lora2.weight.grad}')\n",
    "peft_model(x).sum().backward()\n",
    "# When the adapter is active, we will compute the gradient\n",
    "print(f'after bp, lora1: {peft_model.base_model.model.lin0.lora_B.lora1.weight.grad}')\n",
    "print(f'after bp, lora2: {peft_model.base_model.model.lin0.lora_B.lora2.weight.grad}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "active adapter: lora2\n",
      "before bp, lora1: None\n",
      "before bp, lora2: None\n",
      "after bp, lora1: None\n",
      "after bp, lora2: tensor([[ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000],\n",
      "        [ 0.0056, -0.3897, -0.0461,  0.2852, -0.0939],\n",
      "        [-0.0597,  4.1171,  0.4875, -3.0128,  0.9917],\n",
      "        [ 0.0000, -0.0000, -0.0000,  0.0000, -0.0000]])\n"
     ]
    }
   ],
   "source": [
    "# Switch to another adapter\n",
    "peft_model.set_adapter('lora2')\n",
    "optimizer.zero_grad()\n",
    "print(f'active adapter: {peft_model.active_adapter}')\n",
    "print(f'before bp, lora1: {peft_model.base_model.model.lin0.lora_A.lora1.weight.grad}')\n",
    "print(f'before bp, lora2: {peft_model.base_model.model.lin0.lora_A.lora2.weight.grad}')\n",
    "peft_model(x).sum().backward()\n",
    "print(f'after bp, lora1: {peft_model.base_model.model.lin0.lora_B.lora1.weight.grad}')\n",
    "print(f'after bp, lora2: {peft_model.base_model.model.lin0.lora_B.lora2.weight.grad}')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
